# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Examples of Python implementations of the stock.proto Stock service."""

from grpc.framework.common import cardinality
from grpc.framework.foundation import abandonment
from grpc.framework.foundation import stream
from tests.unit.framework.common import test_constants
from tests.unit.framework.interfaces.face import _service
from tests.unit._junkdrawer import stock_pb2

# Group name shared by every Stock test method in this module.
_STOCK_GROUP_NAME = 'Stock'
# Template used to build test symbols from a zero-padded integer index.
_SYMBOL_FORMAT = 'test symbol:%03d'


def _price(symbol_name):
    """A test-appropriate security-pricing function. :-P

    The price is derived from hash(symbol_name), so it is consistent for a
    given symbol within one interpreter process; string hashing may be
    randomized between processes, so values should not be compared across
    runs.

    Args:
      symbol_name: The hashable symbol to price.

    Returns:
      A float in the range [0.0, 4096.0).
    """
    return float(hash(symbol_name) % 4096)


def _get_last_trade_price(stock_request, stock_reply_callback, control, active):
    """Serve a unary-request, unary-response price lookup for tests.

    Invokes the test control hook, then either hands stock_reply_callback a
    single StockReply echoing the requested symbol with its test price, or
    raises abandonment.Abandoned when active() reports false.
    """
    control.control()
    if not active():
        raise abandonment.Abandoned()
    requested_symbol = stock_request.symbol
    reply = stock_pb2.StockReply(
        symbol=requested_symbol, price=_price(requested_symbol))
    stock_reply_callback(reply)


def _get_last_trade_price_multiple(stock_reply_consumer, control, active):
    """Build the request consumer for a stream-request, stream-response method.

    Every consumed StockRequest is answered with one StockReply carrying the
    same symbol and its test price, forwarded to stock_reply_consumer.

    Returns:
      A stream.Consumer accepting the StockRequest messages.
    """

    def _reply_to(request):
        # Each reply passes through the control hook first; a call that is
        # no longer active is abandoned instead of being answered.
        control.control()
        if not active():
            raise abandonment.Abandoned()
        requested_symbol = request.symbol
        return stock_pb2.StockReply(
            symbol=requested_symbol, price=_price(requested_symbol))

    class StockRequestConsumer(stream.Consumer):
        """Forwards one reply downstream for each request received."""

        def consume(self, stock_request):
            reply = _reply_to(stock_request)
            stock_reply_consumer.consume(reply)

        def terminate(self):
            control.control()
            stock_reply_consumer.terminate()

        def consume_and_terminate(self, stock_request):
            reply = _reply_to(stock_request)
            stock_reply_consumer.consume_and_terminate(reply)

    return StockRequestConsumer()


def _watch_future_trades(stock_request, stock_reply_consumer, control, active):
    """Serve a unary-request, stream-response trade watch for tests.

    Emits stock_request.num_trades_to_watch replies for the requested symbol,
    with prices counting up by one from the symbol's test price, and then
    terminates stock_reply_consumer. The control hook runs before each reply;
    abandonment.Abandoned is raised as soon as active() reports false.
    """
    watched_symbol = stock_request.symbol
    starting_price = _price(watched_symbol)
    for offset in range(stock_request.num_trades_to_watch):
        control.control()
        if not active():
            raise abandonment.Abandoned()
        reply = stock_pb2.StockReply(
            symbol=watched_symbol, price=starting_price + offset)
        stock_reply_consumer.consume(reply)
    stock_reply_consumer.terminate()


def _get_highest_trade_price(stock_reply_callback, control, active):
    """Build the request consumer for a stream-request, unary-response method.

    The returned consumer remembers the highest-priced symbol among the
    requests it has consumed and, when the stream ends, passes a single
    StockReply for that symbol to stock_reply_callback.

    Returns:
      A stream.Consumer accepting the StockRequest messages.
    """

    class StockRequestConsumer(stream.Consumer):
        """Tracks the highest-priced symbol consumed so far."""

        def __init__(self):
            # Both remain None until a request has been recorded.
            self._symbol = None
            self._price = None

        def _beats_record(self, candidate_price):
            # True when nothing is recorded yet or the candidate is strictly
            # higher; on a tie the earlier symbol is kept.
            return self._price is None or self._price < candidate_price

        def _clear_record(self):
            self._symbol = None
            self._price = None

        def consume(self, stock_request):
            control.control()
            if not active():
                # Requests arriving on an inactive call are ignored.
                return
            candidate_price = _price(stock_request.symbol)
            if self._beats_record(candidate_price):
                self._symbol = stock_request.symbol
                self._price = candidate_price

        def terminate(self):
            control.control()
            if not active():
                return
            if self._symbol is None:
                # Nothing was ever recorded, so there is no reply to give.
                raise ValueError()
            stock_reply_callback(
                stock_pb2.StockReply(symbol=self._symbol, price=self._price))
            self._clear_record()

        def consume_and_terminate(self, stock_request):
            control.control()
            if not active():
                return
            candidate_price = _price(stock_request.symbol)
            if self._beats_record(candidate_price):
                winning_symbol = stock_request.symbol
                winning_price = candidate_price
            else:
                winning_symbol = self._symbol
                winning_price = self._price
            stock_reply_callback(
                stock_pb2.StockReply(
                    symbol=winning_symbol, price=winning_price))
            self._clear_record()

    return StockRequestConsumer()


class GetLastTradePrice(_service.UnaryUnaryTestMethodImplementation):
    """Test implementation of the Stock service's GetLastTradePrice method."""

    def group(self):
        """Returns the name of the group this method belongs to."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the name of this method."""
        return 'GetLastTradePrice'

    def cardinality(self):
        """Returns the unary-request, unary-response cardinality."""
        return cardinality.Cardinality.UNARY_UNARY

    def request_class(self):
        """Returns the class of this method's request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the class of this method's response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        """Serializes a request via its SerializeToString method."""
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        """Parses a serialized request into a StockRequest."""
        request_type = stock_pb2.StockRequest
        return request_type.FromString(serialized_request)

    def serialize_response(self, response):
        """Serializes a response via its SerializeToString method."""
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        """Parses a serialized response into a StockReply."""
        response_type = stock_pb2.StockReply
        return response_type.FromString(serialized_response)

    def service(self, request, response_callback, context, control):
        """Services one call, using context.is_active as the liveness check."""
        is_active = context.is_active
        _get_last_trade_price(request, response_callback, control, is_active)


class GetLastTradePriceMessages(_service.UnaryUnaryTestMessages):
    """Request generation and response verification for GetLastTradePrice."""

    def __init__(self):
        # Index used to number the symbol of the next generated request.
        self._index = 0

    def request(self):
        """Returns a StockRequest for the next sequentially-numbered symbol."""
        next_request = stock_pb2.StockRequest(
            symbol=_SYMBOL_FORMAT % self._index)
        self._index += 1
        return next_request

    def verify(self, request, response, test_case):
        """Asserts the response echoes the symbol and carries its test price."""
        expected_symbol = request.symbol
        test_case.assertEqual(expected_symbol, response.symbol)
        test_case.assertEqual(_price(expected_symbol), response.price)


class GetLastTradePriceMultiple(_service.StreamStreamTestMethodImplementation):
    """Test implementation of the Stock service's GetLastTradePriceMultiple."""

    def group(self):
        """Returns the name of the group this method belongs to."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the name of this method."""
        return 'GetLastTradePriceMultiple'

    def cardinality(self):
        """Returns the stream-request, stream-response cardinality."""
        return cardinality.Cardinality.STREAM_STREAM

    def request_class(self):
        """Returns the class of this method's request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the class of this method's response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        """Serializes a request via its SerializeToString method."""
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        """Parses a serialized request into a StockRequest."""
        request_type = stock_pb2.StockRequest
        return request_type.FromString(serialized_request)

    def serialize_response(self, response):
        """Serializes a response via its SerializeToString method."""
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        """Parses a serialized response into a StockReply."""
        response_type = stock_pb2.StockReply
        return response_type.FromString(serialized_response)

    def service(self, response_consumer, context, control):
        """Returns the request consumer that services one call."""
        is_active = context.is_active
        return _get_last_trade_price_multiple(
            response_consumer, control, is_active)


class GetLastTradePriceMultipleMessages(_service.StreamStreamTestMessages):
    """Request streams and response checks for GetLastTradePriceMultiple."""

    def __init__(self):
        # Index of the first symbol in the next generated request stream.
        self._index = 0

    def requests(self):
        """Returns a list of STREAM_LENGTH requests with consecutive symbols.

        The starting index advances by one per call, so successive streams
        begin one symbol later than the previous stream did.
        """
        first_index = self._index
        self._index = first_index + 1
        stop_index = first_index + test_constants.STREAM_LENGTH
        return [
            stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % symbol_index)
            for symbol_index in range(first_index, stop_index)
        ]

    def verify(self, requests, responses, test_case):
        """Asserts each response matches the request at the same position."""
        test_case.assertEqual(len(requests), len(responses))
        for request, response in zip(requests, responses):
            expected_symbol = request.symbol
            test_case.assertEqual(expected_symbol, response.symbol)
            test_case.assertEqual(_price(expected_symbol), response.price)


class WatchFutureTrades(_service.UnaryStreamTestMethodImplementation):
    """Test implementation of the Stock service's WatchFutureTrades method."""

    def group(self):
        """Returns the name of the group this method belongs to."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the name of this method."""
        return 'WatchFutureTrades'

    def cardinality(self):
        """Returns the unary-request, stream-response cardinality."""
        return cardinality.Cardinality.UNARY_STREAM

    def request_class(self):
        """Returns the class of this method's request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the class of this method's response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        """Serializes a request via its SerializeToString method."""
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        """Parses a serialized request into a StockRequest."""
        request_type = stock_pb2.StockRequest
        return request_type.FromString(serialized_request)

    def serialize_response(self, response):
        """Serializes a response via its SerializeToString method."""
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        """Parses a serialized response into a StockReply."""
        response_type = stock_pb2.StockReply
        return response_type.FromString(serialized_response)

    def service(self, request, response_consumer, context, control):
        """Services one call, using context.is_active as the liveness check."""
        is_active = context.is_active
        _watch_future_trades(request, response_consumer, control, is_active)


class WatchFutureTradesMessages(_service.UnaryStreamTestMessages):
    """Pairs of a single request message and a sequence of response messages."""

    def __init__(self):
        # Index used to number the symbol of the next generated request.
        self._index = 0

    def request(self):
        """Returns a request to watch STREAM_LENGTH trades of the next symbol."""
        symbol = _SYMBOL_FORMAT % self._index
        self._index += 1
        return stock_pb2.StockRequest(
            symbol=symbol, num_trades_to_watch=test_constants.STREAM_LENGTH)

    def verify(self, request, responses, test_case):
        """Asserts the responses are the expected trades for the request.

        Checks that exactly STREAM_LENGTH responses arrived, that each one
        names the requested symbol, and that prices count up by one from the
        symbol's test price.
        """
        test_case.assertEqual(test_constants.STREAM_LENGTH, len(responses))
        expected_symbol = request.symbol
        base_price = _price(expected_symbol)
        for index, response in enumerate(responses):
            # The symbol is checked as well as the price, consistent with the
            # other verifiers in this module; otherwise a reply for the wrong
            # symbol would pass unnoticed.
            test_case.assertEqual(expected_symbol, response.symbol)
            test_case.assertEqual(base_price + index, response.price)


class GetHighestTradePrice(_service.StreamUnaryTestMethodImplementation):
    """Test implementation of the Stock service's GetHighestTradePrice method."""

    def group(self):
        """Returns the name of the group this method belongs to."""
        return _STOCK_GROUP_NAME

    def name(self):
        """Returns the name of this method."""
        return 'GetHighestTradePrice'

    def cardinality(self):
        """Returns the stream-request, unary-response cardinality."""
        return cardinality.Cardinality.STREAM_UNARY

    def request_class(self):
        """Returns the class of this method's request messages."""
        return stock_pb2.StockRequest

    def response_class(self):
        """Returns the class of this method's response messages."""
        return stock_pb2.StockReply

    def serialize_request(self, request):
        """Serializes a request via its SerializeToString method."""
        return request.SerializeToString()

    def deserialize_request(self, serialized_request):
        """Parses a serialized request into a StockRequest."""
        request_type = stock_pb2.StockRequest
        return request_type.FromString(serialized_request)

    def serialize_response(self, response):
        """Serializes a response via its SerializeToString method."""
        return response.SerializeToString()

    def deserialize_response(self, serialized_response):
        """Parses a serialized response into a StockReply."""
        response_type = stock_pb2.StockReply
        return response_type.FromString(serialized_response)

    def service(self, response_callback, context, control):
        """Returns the request consumer that services one call."""
        is_active = context.is_active
        return _get_highest_trade_price(response_callback, control, is_active)


class GetHighestTradePriceMessages(_service.StreamUnaryTestMessages):
    """Request streams and response verification for GetHighestTradePrice."""

    def requests(self):
        """Returns STREAM_LENGTH requests for symbols numbered from zero."""
        symbols = [
            _SYMBOL_FORMAT % index
            for index in range(test_constants.STREAM_LENGTH)
        ]
        return [stock_pb2.StockRequest(symbol=symbol) for symbol in symbols]

    def verify(self, requests, response, test_case):
        """Asserts the response names the highest-priced requested symbol.

        When several symbols share the highest price, the earliest request is
        the expected one. With no requests, both expected values are None.
        """
        best_symbol = None
        best_price = None
        for stock_request in requests:
            candidate_symbol = stock_request.symbol
            candidate_price = _price(candidate_symbol)
            if best_price is not None and best_price >= candidate_price:
                continue
            best_symbol, best_price = candidate_symbol, candidate_price
        test_case.assertEqual(best_price, response.price)
        test_case.assertEqual(best_symbol, response.symbol)


class StockTestService(_service.TestService):
    """Test scenarios covering one Stock method of each RPC cardinality.

    Each *_scenarios method returns a dict mapping a (group, method name)
    pair to a tuple of the method implementation and a list of message
    generators/verifiers; fresh instances are created on every call.
    """

    def unary_unary_scenarios(self):
        """Returns the GetLastTradePrice scenario."""
        scenario_key = (_STOCK_GROUP_NAME, 'GetLastTradePrice')
        scenario = (GetLastTradePrice(), [GetLastTradePriceMessages()])
        return {scenario_key: scenario}

    def unary_stream_scenarios(self):
        """Returns the WatchFutureTrades scenario."""
        scenario_key = (_STOCK_GROUP_NAME, 'WatchFutureTrades')
        scenario = (WatchFutureTrades(), [WatchFutureTradesMessages()])
        return {scenario_key: scenario}

    def stream_unary_scenarios(self):
        """Returns the GetHighestTradePrice scenario."""
        scenario_key = (_STOCK_GROUP_NAME, 'GetHighestTradePrice')
        scenario = (GetHighestTradePrice(), [GetHighestTradePriceMessages()])
        return {scenario_key: scenario}

    def stream_stream_scenarios(self):
        """Returns the GetLastTradePriceMultiple scenario."""
        scenario_key = (_STOCK_GROUP_NAME, 'GetLastTradePriceMultiple')
        scenario = (GetLastTradePriceMultiple(),
                    [GetLastTradePriceMultipleMessages()])
        return {scenario_key: scenario}


# A ready-made StockTestService instance, created when this module is imported.
STOCK_TEST_SERVICE = StockTestService()
